1 //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 namespace {
15 /// This is a test pass for verifying FuncOp's insertArgument method.
16 struct TestFuncInsertArg
17     : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncInsertArg18   StringRef getArgument() const final { return "test-func-insert-arg"; }
getDescription__anon70e73dd40111::TestFuncInsertArg19   StringRef getDescription() const final { return "Test inserting func args."; }
runOnOperation__anon70e73dd40111::TestFuncInsertArg20   void runOnOperation() override {
21     auto module = getOperation();
22 
23     for (FuncOp func : module.getOps<FuncOp>()) {
24       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
25       if (!inserts || inserts.empty())
26         continue;
27       SmallVector<unsigned, 4> indicesToInsert;
28       SmallVector<Type, 4> typesToInsert;
29       SmallVector<DictionaryAttr, 4> attrsToInsert;
30       SmallVector<Optional<Location>, 4> locsToInsert;
31       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
32         indicesToInsert.push_back(
33             insert[0].cast<IntegerAttr>().getValue().getZExtValue());
34         typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
35         attrsToInsert.push_back(insert.size() > 2
36                                     ? insert[2].cast<DictionaryAttr>()
37                                     : DictionaryAttr::get(&getContext()));
38         locsToInsert.push_back(
39             insert.size() > 3
40                 ? Optional<Location>(insert[3].cast<LocationAttr>())
41                 : Optional<Location>{});
42       }
43       func->removeAttr("test.insert_args");
44       func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
45                            locsToInsert);
46     }
47   }
48 };
49 
50 /// This is a test pass for verifying FuncOp's insertResult method.
51 struct TestFuncInsertResult
52     : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncInsertResult53   StringRef getArgument() const final { return "test-func-insert-result"; }
getDescription__anon70e73dd40111::TestFuncInsertResult54   StringRef getDescription() const final {
55     return "Test inserting func results.";
56   }
runOnOperation__anon70e73dd40111::TestFuncInsertResult57   void runOnOperation() override {
58     auto module = getOperation();
59 
60     for (FuncOp func : module.getOps<FuncOp>()) {
61       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
62       if (!inserts || inserts.empty())
63         continue;
64       SmallVector<unsigned, 4> indicesToInsert;
65       SmallVector<Type, 4> typesToInsert;
66       SmallVector<DictionaryAttr, 4> attrsToInsert;
67       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
68         indicesToInsert.push_back(
69             insert[0].cast<IntegerAttr>().getValue().getZExtValue());
70         typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
71         attrsToInsert.push_back(insert.size() > 2
72                                     ? insert[2].cast<DictionaryAttr>()
73                                     : DictionaryAttr::get(&getContext()));
74       }
75       func->removeAttr("test.insert_results");
76       func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
77     }
78   }
79 };
80 
81 /// This is a test pass for verifying FuncOp's eraseArgument method.
82 struct TestFuncEraseArg
83     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncEraseArg84   StringRef getArgument() const final { return "test-func-erase-arg"; }
getDescription__anon70e73dd40111::TestFuncEraseArg85   StringRef getDescription() const final { return "Test erasing func args."; }
runOnOperation__anon70e73dd40111::TestFuncEraseArg86   void runOnOperation() override {
87     auto module = getOperation();
88 
89     for (FuncOp func : module.getOps<FuncOp>()) {
90       SmallVector<unsigned, 4> indicesToErase;
91       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
92         if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
93           // Push back twice to test that duplicate arg indices are handled
94           // correctly.
95           indicesToErase.push_back(argIndex);
96           indicesToErase.push_back(argIndex);
97         }
98       }
99       // Reverse the order to test that unsorted index lists are handled
100       // correctly.
101       std::reverse(indicesToErase.begin(), indicesToErase.end());
102       func.eraseArguments(indicesToErase);
103     }
104   }
105 };
106 
107 /// This is a test pass for verifying FuncOp's eraseResult method.
108 struct TestFuncEraseResult
109     : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncEraseResult110   StringRef getArgument() const final { return "test-func-erase-result"; }
getDescription__anon70e73dd40111::TestFuncEraseResult111   StringRef getDescription() const final {
112     return "Test erasing func results.";
113   }
runOnOperation__anon70e73dd40111::TestFuncEraseResult114   void runOnOperation() override {
115     auto module = getOperation();
116 
117     for (FuncOp func : module.getOps<FuncOp>()) {
118       SmallVector<unsigned, 4> indicesToErase;
119       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
120         if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
121           // Push back twice to test that duplicate indices are handled
122           // correctly.
123           indicesToErase.push_back(resultIndex);
124           indicesToErase.push_back(resultIndex);
125         }
126       }
127       // Reverse the order to test that unsorted index lists are handled
128       // correctly.
129       std::reverse(indicesToErase.begin(), indicesToErase.end());
130       func.eraseResults(indicesToErase);
131     }
132   }
133 };
134 
135 /// This is a test pass for verifying FuncOp's setType method.
136 struct TestFuncSetType
137     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncSetType138   StringRef getArgument() const final { return "test-func-set-type"; }
getDescription__anon70e73dd40111::TestFuncSetType139   StringRef getDescription() const final { return "Test FuncOp::setType."; }
runOnOperation__anon70e73dd40111::TestFuncSetType140   void runOnOperation() override {
141     auto module = getOperation();
142     SymbolTable symbolTable(module);
143 
144     for (FuncOp func : module.getOps<FuncOp>()) {
145       auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
146       if (!sym)
147         continue;
148       func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
149     }
150   }
151 };
152 } // end anonymous namespace
153 
154 namespace mlir {
registerTestFunc()155 void registerTestFunc() {
156   PassRegistration<TestFuncInsertArg>();
157 
158   PassRegistration<TestFuncInsertResult>();
159 
160   PassRegistration<TestFuncEraseArg>();
161 
162   PassRegistration<TestFuncEraseResult>();
163 
164   PassRegistration<TestFuncSetType>();
165 }
166 } // namespace mlir
167