1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/FunctionSupport.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/BitVector.h"
12
13 using namespace mlir;
14
15 /// Helper to call a callback once on each index in the range
16 /// [0, `totalIndices`), *except* for the indices given in `indices`.
17 /// `indices` is allowed to have duplicates and can be in any order.
iterateIndicesExcept(unsigned totalIndices,ArrayRef<unsigned> indices,function_ref<void (unsigned)> callback)18 inline void iterateIndicesExcept(unsigned totalIndices,
19 ArrayRef<unsigned> indices,
20 function_ref<void(unsigned)> callback) {
21 llvm::BitVector skipIndices(totalIndices);
22 for (unsigned i : indices)
23 skipIndices.set(i);
24
25 for (unsigned i = 0; i < totalIndices; ++i)
26 if (!skipIndices.test(i))
27 callback(i);
28 }
29
30 //===----------------------------------------------------------------------===//
31 // Function Arguments and Results.
32 //===----------------------------------------------------------------------===//
33
isEmptyAttrDict(Attribute attr)34 static bool isEmptyAttrDict(Attribute attr) {
35 return attr.cast<DictionaryAttr>().empty();
36 }
37
getArgAttrDict(Operation * op,unsigned index)38 DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op,
39 unsigned index) {
40 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
41 DictionaryAttr argAttrs =
42 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
43 return argAttrs;
44 }
45
getResultAttrDict(Operation * op,unsigned index)46 DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op,
47 unsigned index) {
48 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
49 DictionaryAttr resAttrs =
50 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
51 return resAttrs;
52 }
53
setArgResAttrDict(Operation * op,StringRef attrName,unsigned numTotalIndices,unsigned index,DictionaryAttr attrs)54 void mlir::function_like_impl::detail::setArgResAttrDict(
55 Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
56 DictionaryAttr attrs) {
57 ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
58 if (!allAttrs) {
59 if (attrs.empty())
60 return;
61
62 // If this attribute is not empty, we need to create a new attribute array.
63 SmallVector<Attribute, 8> newAttrs(numTotalIndices,
64 DictionaryAttr::get(op->getContext()));
65 newAttrs[index] = attrs;
66 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
67 return;
68 }
69 // Check to see if the attribute is different from what we already have.
70 if (allAttrs[index] == attrs)
71 return;
72
73 // If it is, check to see if the attribute array would now contain only empty
74 // dictionaries.
75 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
76 if (attrs.empty() &&
77 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
78 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
79 op->removeAttr(attrName);
80 return;
81 }
82
83 // Otherwise, create a new attribute array with the updated dictionary.
84 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
85 newAttrs[index] = attrs;
86 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
87 }
88
89 /// Set all of the argument or result attribute dictionaries for a function.
setAllArgResAttrDicts(Operation * op,StringRef attrName,ArrayRef<Attribute> attrs)90 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
91 ArrayRef<Attribute> attrs) {
92 if (llvm::all_of(attrs, isEmptyAttrDict))
93 op->removeAttr(attrName);
94 else
95 op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
96 }
97
setAllArgAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)98 void mlir::function_like_impl::setAllArgAttrDicts(
99 Operation *op, ArrayRef<DictionaryAttr> attrs) {
100 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
101 }
setAllArgAttrDicts(Operation * op,ArrayRef<Attribute> attrs)102 void mlir::function_like_impl::setAllArgAttrDicts(Operation *op,
103 ArrayRef<Attribute> attrs) {
104 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
105 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
106 });
107 setAllArgResAttrDicts(op, getArgDictAttrName(),
108 llvm::to_vector<8>(wrappedAttrs));
109 }
110
setAllResultAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)111 void mlir::function_like_impl::setAllResultAttrDicts(
112 Operation *op, ArrayRef<DictionaryAttr> attrs) {
113 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
114 }
setAllResultAttrDicts(Operation * op,ArrayRef<Attribute> attrs)115 void mlir::function_like_impl::setAllResultAttrDicts(
116 Operation *op, ArrayRef<Attribute> attrs) {
117 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
118 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
119 });
120 setAllArgResAttrDicts(op, getResultDictAttrName(),
121 llvm::to_vector<8>(wrappedAttrs));
122 }
123
insertFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<DictionaryAttr> argAttrs,ArrayRef<Optional<Location>> argLocs,unsigned originalNumArgs,Type newType)124 void mlir::function_like_impl::insertFunctionArguments(
125 Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
126 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs,
127 unsigned originalNumArgs, Type newType) {
128 assert(argIndices.size() == argTypes.size());
129 assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
130 assert(argIndices.size() == argLocs.size() || argLocs.empty());
131 if (argIndices.empty())
132 return;
133
134 // There are 3 things that need to be updated:
135 // - Function type.
136 // - Arg attrs.
137 // - Block arguments of entry block.
138 Block &entry = op->getRegion(0).front();
139
140 // Update the argument attributes of the function.
141 auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
142 if (oldArgAttrs || !argAttrs.empty()) {
143 SmallVector<DictionaryAttr, 4> newArgAttrs;
144 newArgAttrs.reserve(originalNumArgs + argIndices.size());
145 unsigned oldIdx = 0;
146 auto migrate = [&](unsigned untilIdx) {
147 if (!oldArgAttrs) {
148 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
149 } else {
150 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
151 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
152 oldArgAttrRange.begin() + untilIdx);
153 }
154 oldIdx = untilIdx;
155 };
156 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
157 migrate(argIndices[i]);
158 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
159 }
160 migrate(originalNumArgs);
161 setAllArgAttrDicts(op, newArgAttrs);
162 }
163
164 // Update the function type and any entry block arguments.
165 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
166 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
167 entry.insertArgument(argIndices[i], argTypes[i],
168 argLocs.empty() ? Optional<Location>{} : argLocs[i]);
169 }
170
insertFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,TypeRange resultTypes,ArrayRef<DictionaryAttr> resultAttrs,unsigned originalNumResults,Type newType)171 void mlir::function_like_impl::insertFunctionResults(
172 Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
173 ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
174 Type newType) {
175 assert(resultIndices.size() == resultTypes.size());
176 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
177 if (resultIndices.empty())
178 return;
179
180 // There are 2 things that need to be updated:
181 // - Function type.
182 // - Result attrs.
183
184 // Update the result attributes of the function.
185 auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
186 if (oldResultAttrs || !resultAttrs.empty()) {
187 SmallVector<DictionaryAttr, 4> newResultAttrs;
188 newResultAttrs.reserve(originalNumResults + resultIndices.size());
189 unsigned oldIdx = 0;
190 auto migrate = [&](unsigned untilIdx) {
191 if (!oldResultAttrs) {
192 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
193 } else {
194 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
195 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
196 oldResultAttrsRange.begin() + untilIdx);
197 }
198 oldIdx = untilIdx;
199 };
200 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
201 migrate(resultIndices[i]);
202 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
203 : resultAttrs[i]);
204 }
205 migrate(originalNumResults);
206 setAllResultAttrDicts(op, newResultAttrs);
207 }
208
209 // Update the function type.
210 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
211 }
212
eraseFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,unsigned originalNumArgs,Type newType)213 void mlir::function_like_impl::eraseFunctionArguments(
214 Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
215 Type newType) {
216 // There are 3 things that need to be updated:
217 // - Function type.
218 // - Arg attrs.
219 // - Block arguments of entry block.
220 Block &entry = op->getRegion(0).front();
221
222 // Update the argument attributes of the function.
223 if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
224 SmallVector<DictionaryAttr, 4> newArgAttrs;
225 newArgAttrs.reserve(argAttrs.size());
226 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
227 newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
228 });
229 setAllArgAttrDicts(op, newArgAttrs);
230 }
231
232 // Update the function type and any entry block arguments.
233 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
234 entry.eraseArguments(argIndices);
235 }
236
eraseFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,unsigned originalNumResults,Type newType)237 void mlir::function_like_impl::eraseFunctionResults(
238 Operation *op, ArrayRef<unsigned> resultIndices,
239 unsigned originalNumResults, Type newType) {
240 // There are 2 things that need to be updated:
241 // - Function type.
242 // - Result attrs.
243
244 // Update the result attributes of the function.
245 if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
246 SmallVector<DictionaryAttr, 4> newResultAttrs;
247 newResultAttrs.reserve(resAttrs.size());
248 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
249 newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
250 });
251 setAllResultAttrDicts(op, newResultAttrs);
252 }
253
254 // Update the function type.
255 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
256 }
257
258 //===----------------------------------------------------------------------===//
259 // Function type signature.
260 //===----------------------------------------------------------------------===//
261
getFunctionType(Operation * op)262 FunctionType mlir::function_like_impl::getFunctionType(Operation *op) {
263 assert(op->hasTrait<OpTrait::FunctionLike>());
264 return op->getAttrOfType<TypeAttr>(getTypeAttrName())
265 .getValue()
266 .cast<FunctionType>();
267 }
268
setFunctionType(Operation * op,FunctionType newType)269 void mlir::function_like_impl::setFunctionType(Operation *op,
270 FunctionType newType) {
271 assert(op->hasTrait<OpTrait::FunctionLike>());
272 FunctionType oldType = getFunctionType(op);
273 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
274
275 // Functor used to update the argument and result attributes of the function.
276 auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
277 unsigned newCount, auto setAttrFn) {
278 if (oldCount == newCount)
279 return;
280 // The new type has no arguments/results, just drop the attribute.
281 if (newCount == 0) {
282 op->removeAttr(attrName);
283 return;
284 }
285 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
286 if (!attrs)
287 return;
288
289 // The new type has less arguments/results, take the first N attributes.
290 if (newCount < oldCount)
291 return setAttrFn(op, attrs.getValue().take_front(newCount));
292
293 // Otherwise, the new type has more arguments/results. Initialize the new
294 // arguments/results with empty attributes.
295 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
296 newAttrs.resize(newCount);
297 setAttrFn(op, newAttrs);
298 };
299
300 // Update the argument and result attributes.
301 updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(),
302 newType.getNumInputs(), [&](Operation *op, auto &&attrs) {
303 setAllArgAttrDicts(op, attrs);
304 });
305 updateAttrFn(
306 function_like_impl::getResultDictAttrName(), oldType.getNumResults(),
307 newType.getNumResults(),
308 [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
309 }
310
311 //===----------------------------------------------------------------------===//
312 // Function body.
313 //===----------------------------------------------------------------------===//
314
getFunctionBody(Operation * op)315 Region &mlir::function_like_impl::getFunctionBody(Operation *op) {
316 assert(op->hasTrait<OpTrait::FunctionLike>());
317 return op->getRegion(0);
318 }
319