1 //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Function.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 namespace {
15 /// This is a test pass for verifying FuncOp's eraseArgument method.
16 struct TestFuncEraseArg
17     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
runOnOperation__anon97024c9c0111::TestFuncEraseArg18   void runOnOperation() override {
19     auto module = getOperation();
20 
21     for (FuncOp func : module.getOps<FuncOp>()) {
22       SmallVector<unsigned, 4> indicesToErase;
23       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
24         if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
25           // Push back twice to test that duplicate arg indices are handled
26           // correctly.
27           indicesToErase.push_back(argIndex);
28           indicesToErase.push_back(argIndex);
29         }
30       }
31       // Reverse the order to test that unsorted index lists are handled
32       // correctly.
33       std::reverse(indicesToErase.begin(), indicesToErase.end());
34       func.eraseArguments(indicesToErase);
35     }
36   }
37 };
38 
39 /// This is a test pass for verifying FuncOp's setType method.
40 struct TestFuncSetType
41     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
runOnOperation__anon97024c9c0111::TestFuncSetType42   void runOnOperation() override {
43     auto module = getOperation();
44     SymbolTable symbolTable(module);
45 
46     for (FuncOp func : module.getOps<FuncOp>()) {
47       auto sym = func.getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
48       if (!sym)
49         continue;
50       func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
51     }
52   }
53 };
54 } // end anonymous namespace
55 
56 namespace mlir {
registerTestFunc()57 void registerTestFunc() {
58   PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
59                                           "Test erasing func args.");
60 
61   PassRegistration<TestFuncSetType> pass2("test-func-set-type",
62                                           "Test FuncOp::setType.");
63 }
64 } // namespace mlir
65