1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 namespace {
19 /// Analysis that operates on any operation.
20 struct GenericAnalysis {
GenericAnalysis__anon6b73f1430111::GenericAnalysis21   GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
22   const bool isFunc;
23 };
24 
25 /// Analysis that operates on a specific operation.
26 struct OpSpecificAnalysis {
OpSpecificAnalysis__anon6b73f1430111::OpSpecificAnalysis27   OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
28   const bool isSecret;
29 };
30 
31 /// Simple pass to annotate a FuncOp with the results of analysis.
32 /// Note: not using FunctionPass as it skip external functions.
33 struct AnnotateFunctionPass
34     : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
runOnOperation__anon6b73f1430111::AnnotateFunctionPass35   void runOnOperation() override {
36     FuncOp op = getOperation();
37     Builder builder(op->getParentOfType<ModuleOp>());
38 
39     auto &ga = getAnalysis<GenericAnalysis>();
40     auto &sa = getAnalysis<OpSpecificAnalysis>();
41 
42     op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
43     op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
44   }
45 };
46 
/// Checks that a pass anchored on FuncOp can query both a generic analysis
/// and an analysis that is only constructible from a FuncOp, and that the
/// results end up as attributes on every function in the module.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    FuncOp func =
        FuncOp::create(builder.getUnknownLoc(), name,
                       builder.getFunctionType(llvm::None, llvm::None));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass. Stop on failure: the checks below would
  // only produce confusing follow-on errors.
  PassManager pm(&context);
  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  ASSERT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // `getAttrOfType` yields a null attribute when the attribute is missing or
  // has a different type. A missing annotation therefore fails the test
  // instead of tripping the null-attribute assertion in `Attribute::isa`.
  unsigned numFuncs = 0;
  for (FuncOp func : module->getOps<FuncOp>()) {
    ++numFuncs;

    BoolAttr isFuncAttr = func->getAttrOfType<BoolAttr>("isFunc");
    ASSERT_TRUE(isFuncAttr) << "missing or non-bool 'isFunc' attribute";
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    BoolAttr isSecretAttr = func->getAttrOfType<BoolAttr>("isSecret");
    ASSERT_TRUE(isSecretAttr) << "missing or non-bool 'isSecret' attribute";
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
  // Guard against the loop above silently checking nothing.
  EXPECT_EQ(numFuncs, 2u);
}
77 
78 namespace {
79 struct InvalidPass : Pass {
InvalidPass__anon6b73f1430111::__anon6b73f1430211::InvalidPass80   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon6b73f1430111::__anon6b73f1430211::InvalidPass81   StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon6b73f1430111::__anon6b73f1430211::InvalidPass82   void runOnOperation() override {}
83 
84   /// A clone method to create a copy of this pass.
clonePass__anon6b73f1430111::__anon6b73f1430211::InvalidPass85   std::unique_ptr<Pass> clonePass() const override {
86     return std::make_unique<InvalidPass>(
87         *static_cast<const InvalidPass *>(this));
88   }
89 };
90 } // anonymous namespace
91 
/// Checks that running a pass nested on an unregistered operation fails with
/// a diagnostic, and that adding such a pass directly to the top-level pass
/// manager is a fatal error.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // Create a module
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));

  // Add a single "invalid_op" operation
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(UnknownLoc::get(&context), "invalid_op");
  builder.insert(Operation::create(state));

  // Register a diagnostic handler to capture the diagnostic so that we can
  // check it later. Each new diagnostic replaces the previous one, so only the
  // most recently reported diagnostic is kept.
  std::unique_ptr<Diagnostic> diagnostic;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic != nullptr);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Check that adding the pass at the top-level triggers a fatal error.
  // The empty regex accepts any death message.
  // NOTE(review): consider matching the fatal-error text once it is confirmed
  // for this MLIR version.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
124 
125 } // end namespace
126