1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 namespace {
19 /// Analysis that operates on any operation.
20 struct GenericAnalysis {
GenericAnalysis__anon097f9e560111::GenericAnalysis21   GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
22   const bool isFunc;
23 };
24 
25 /// Analysis that operates on a specific operation.
26 struct OpSpecificAnalysis {
OpSpecificAnalysis__anon097f9e560111::OpSpecificAnalysis27   OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
28   const bool isSecret;
29 };
30 
31 /// Simple pass to annotate a FuncOp with the results of analysis.
32 /// Note: not using FunctionPass as it skip external functions.
33 struct AnnotateFunctionPass
34     : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
runOnOperation__anon097f9e560111::AnnotateFunctionPass35   void runOnOperation() override {
36     FuncOp op = getOperation();
37     Builder builder(op.getParentOfType<ModuleOp>());
38 
39     auto &ga = getAnalysis<GenericAnalysis>();
40     auto &sa = getAnalysis<OpSpecificAnalysis>();
41 
42     op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
43     op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
44   }
45 };
46 
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    FuncOp func =
        FuncOp::create(builder.getUnknownLoc(), name,
                       builder.getFunctionType(llvm::None, llvm::None));
    module->push_back(func);
  }

  // Instantiate and run our pass. If the pipeline fails, the attributes
  // inspected below may be missing, so stop the test here.
  PassManager pm(&context);
  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  ASSERT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // dyn_cast_or_null is used so that a missing attribute is reported as a
  // test failure instead of calling isa<> on a null Attribute.
  unsigned numFuncs = 0;
  for (FuncOp func : module->getOps<FuncOp>()) {
    ++numFuncs;

    BoolAttr isFuncAttr = func.getAttr("isFunc").dyn_cast_or_null<BoolAttr>();
    ASSERT_TRUE(isFuncAttr);
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    BoolAttr isSecretAttr =
        func.getAttr("isSecret").dyn_cast_or_null<BoolAttr>();
    ASSERT_TRUE(isSecretAttr);
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }

  // Both functions created above must have been visited and checked.
  EXPECT_EQ(numFuncs, 2u);
}
76 
77 namespace {
78 struct InvalidPass : Pass {
InvalidPass__anon097f9e560111::__anon097f9e560211::InvalidPass79   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon097f9e560111::__anon097f9e560211::InvalidPass80   StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon097f9e560111::__anon097f9e560211::InvalidPass81   void runOnOperation() override {}
82 
83   /// A clone method to create a copy of this pass.
clonePass__anon097f9e560111::__anon097f9e560211::InvalidPass84   std::unique_ptr<Pass> clonePass() const override {
85     return std::make_unique<InvalidPass>(
86         *static_cast<const InvalidPass *>(this));
87   }
88 };
89 } // anonymous namespace
90 
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;

  // Create a module.
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));

  // Add a single "invalid_op" operation. It is created from a bare
  // OperationState, and the expected diagnostic below reports it as an
  // unregistered operation.
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(UnknownLoc::get(&context), "invalid_op");
  builder.insert(Operation::create(state));

  // Register a diagnostic handler to capture the diagnostic so that we can
  // check it later. Only the most recently emitted diagnostic is kept.
  std::unique_ptr<Diagnostic> diagnostic;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic.get() != nullptr);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Check that adding the pass at the top-level triggers a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
122 
123 } // end namespace
124