1 //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11
12 using namespace mlir;
13
14 namespace {
15 /// This is a test pass for verifying FuncOp's insertArgument method.
16 struct TestFuncInsertArg
17 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncInsertArg18 StringRef getArgument() const final { return "test-func-insert-arg"; }
getDescription__anon70e73dd40111::TestFuncInsertArg19 StringRef getDescription() const final { return "Test inserting func args."; }
runOnOperation__anon70e73dd40111::TestFuncInsertArg20 void runOnOperation() override {
21 auto module = getOperation();
22
23 for (FuncOp func : module.getOps<FuncOp>()) {
24 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
25 if (!inserts || inserts.empty())
26 continue;
27 SmallVector<unsigned, 4> indicesToInsert;
28 SmallVector<Type, 4> typesToInsert;
29 SmallVector<DictionaryAttr, 4> attrsToInsert;
30 SmallVector<Optional<Location>, 4> locsToInsert;
31 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
32 indicesToInsert.push_back(
33 insert[0].cast<IntegerAttr>().getValue().getZExtValue());
34 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
35 attrsToInsert.push_back(insert.size() > 2
36 ? insert[2].cast<DictionaryAttr>()
37 : DictionaryAttr::get(&getContext()));
38 locsToInsert.push_back(
39 insert.size() > 3
40 ? Optional<Location>(insert[3].cast<LocationAttr>())
41 : Optional<Location>{});
42 }
43 func->removeAttr("test.insert_args");
44 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
45 locsToInsert);
46 }
47 }
48 };
49
50 /// This is a test pass for verifying FuncOp's insertResult method.
51 struct TestFuncInsertResult
52 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncInsertResult53 StringRef getArgument() const final { return "test-func-insert-result"; }
getDescription__anon70e73dd40111::TestFuncInsertResult54 StringRef getDescription() const final {
55 return "Test inserting func results.";
56 }
runOnOperation__anon70e73dd40111::TestFuncInsertResult57 void runOnOperation() override {
58 auto module = getOperation();
59
60 for (FuncOp func : module.getOps<FuncOp>()) {
61 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
62 if (!inserts || inserts.empty())
63 continue;
64 SmallVector<unsigned, 4> indicesToInsert;
65 SmallVector<Type, 4> typesToInsert;
66 SmallVector<DictionaryAttr, 4> attrsToInsert;
67 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
68 indicesToInsert.push_back(
69 insert[0].cast<IntegerAttr>().getValue().getZExtValue());
70 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
71 attrsToInsert.push_back(insert.size() > 2
72 ? insert[2].cast<DictionaryAttr>()
73 : DictionaryAttr::get(&getContext()));
74 }
75 func->removeAttr("test.insert_results");
76 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
77 }
78 }
79 };
80
81 /// This is a test pass for verifying FuncOp's eraseArgument method.
82 struct TestFuncEraseArg
83 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncEraseArg84 StringRef getArgument() const final { return "test-func-erase-arg"; }
getDescription__anon70e73dd40111::TestFuncEraseArg85 StringRef getDescription() const final { return "Test erasing func args."; }
runOnOperation__anon70e73dd40111::TestFuncEraseArg86 void runOnOperation() override {
87 auto module = getOperation();
88
89 for (FuncOp func : module.getOps<FuncOp>()) {
90 SmallVector<unsigned, 4> indicesToErase;
91 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
92 if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
93 // Push back twice to test that duplicate arg indices are handled
94 // correctly.
95 indicesToErase.push_back(argIndex);
96 indicesToErase.push_back(argIndex);
97 }
98 }
99 // Reverse the order to test that unsorted index lists are handled
100 // correctly.
101 std::reverse(indicesToErase.begin(), indicesToErase.end());
102 func.eraseArguments(indicesToErase);
103 }
104 }
105 };
106
107 /// This is a test pass for verifying FuncOp's eraseResult method.
108 struct TestFuncEraseResult
109 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncEraseResult110 StringRef getArgument() const final { return "test-func-erase-result"; }
getDescription__anon70e73dd40111::TestFuncEraseResult111 StringRef getDescription() const final {
112 return "Test erasing func results.";
113 }
runOnOperation__anon70e73dd40111::TestFuncEraseResult114 void runOnOperation() override {
115 auto module = getOperation();
116
117 for (FuncOp func : module.getOps<FuncOp>()) {
118 SmallVector<unsigned, 4> indicesToErase;
119 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
120 if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
121 // Push back twice to test that duplicate indices are handled
122 // correctly.
123 indicesToErase.push_back(resultIndex);
124 indicesToErase.push_back(resultIndex);
125 }
126 }
127 // Reverse the order to test that unsorted index lists are handled
128 // correctly.
129 std::reverse(indicesToErase.begin(), indicesToErase.end());
130 func.eraseResults(indicesToErase);
131 }
132 }
133 };
134
135 /// This is a test pass for verifying FuncOp's setType method.
136 struct TestFuncSetType
137 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
getArgument__anon70e73dd40111::TestFuncSetType138 StringRef getArgument() const final { return "test-func-set-type"; }
getDescription__anon70e73dd40111::TestFuncSetType139 StringRef getDescription() const final { return "Test FuncOp::setType."; }
runOnOperation__anon70e73dd40111::TestFuncSetType140 void runOnOperation() override {
141 auto module = getOperation();
142 SymbolTable symbolTable(module);
143
144 for (FuncOp func : module.getOps<FuncOp>()) {
145 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
146 if (!sym)
147 continue;
148 func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
149 }
150 }
151 };
152 } // end anonymous namespace
153
154 namespace mlir {
registerTestFunc()155 void registerTestFunc() {
156 PassRegistration<TestFuncInsertArg>();
157
158 PassRegistration<TestFuncInsertResult>();
159
160 PassRegistration<TestFuncEraseArg>();
161
162 PassRegistration<TestFuncEraseResult>();
163
164 PassRegistration<TestFuncSetType>();
165 }
166 } // namespace mlir
167