1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionSupport.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/BitVector.h"
12 
13 using namespace mlir;
14 
15 /// Helper to call a callback once on each index in the range
16 /// [0, `totalIndices`), *except* for the indices given in `indices`.
17 /// `indices` is allowed to have duplicates and can be in any order.
iterateIndicesExcept(unsigned totalIndices,ArrayRef<unsigned> indices,function_ref<void (unsigned)> callback)18 inline void iterateIndicesExcept(unsigned totalIndices,
19                                  ArrayRef<unsigned> indices,
20                                  function_ref<void(unsigned)> callback) {
21   llvm::BitVector skipIndices(totalIndices);
22   for (unsigned i : indices)
23     skipIndices.set(i);
24 
25   for (unsigned i = 0; i < totalIndices; ++i)
26     if (!skipIndices.test(i))
27       callback(i);
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Function Arguments and Results.
32 //===----------------------------------------------------------------------===//
33 
isEmptyAttrDict(Attribute attr)34 static bool isEmptyAttrDict(Attribute attr) {
35   return attr.cast<DictionaryAttr>().empty();
36 }
37 
getArgAttrDict(Operation * op,unsigned index)38 DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op,
39                                                         unsigned index) {
40   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
41   DictionaryAttr argAttrs =
42       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
43   return argAttrs;
44 }
45 
getResultAttrDict(Operation * op,unsigned index)46 DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op,
47                                                            unsigned index) {
48   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
49   DictionaryAttr resAttrs =
50       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
51   return resAttrs;
52 }
53 
setArgResAttrDict(Operation * op,StringRef attrName,unsigned numTotalIndices,unsigned index,DictionaryAttr attrs)54 void mlir::function_like_impl::detail::setArgResAttrDict(
55     Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
56     DictionaryAttr attrs) {
57   ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
58   if (!allAttrs) {
59     if (attrs.empty())
60       return;
61 
62     // If this attribute is not empty, we need to create a new attribute array.
63     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
64                                        DictionaryAttr::get(op->getContext()));
65     newAttrs[index] = attrs;
66     op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
67     return;
68   }
69   // Check to see if the attribute is different from what we already have.
70   if (allAttrs[index] == attrs)
71     return;
72 
73   // If it is, check to see if the attribute array would now contain only empty
74   // dictionaries.
75   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
76   if (attrs.empty() &&
77       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
78       llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
79     op->removeAttr(attrName);
80     return;
81   }
82 
83   // Otherwise, create a new attribute array with the updated dictionary.
84   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
85   newAttrs[index] = attrs;
86   op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
87 }
88 
89 /// Set all of the argument or result attribute dictionaries for a function.
setAllArgResAttrDicts(Operation * op,StringRef attrName,ArrayRef<Attribute> attrs)90 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
91                                   ArrayRef<Attribute> attrs) {
92   if (llvm::all_of(attrs, isEmptyAttrDict))
93     op->removeAttr(attrName);
94   else
95     op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
96 }
97 
setAllArgAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)98 void mlir::function_like_impl::setAllArgAttrDicts(
99     Operation *op, ArrayRef<DictionaryAttr> attrs) {
100   setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
101 }
setAllArgAttrDicts(Operation * op,ArrayRef<Attribute> attrs)102 void mlir::function_like_impl::setAllArgAttrDicts(Operation *op,
103                                                   ArrayRef<Attribute> attrs) {
104   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
105     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
106   });
107   setAllArgResAttrDicts(op, getArgDictAttrName(),
108                         llvm::to_vector<8>(wrappedAttrs));
109 }
110 
setAllResultAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)111 void mlir::function_like_impl::setAllResultAttrDicts(
112     Operation *op, ArrayRef<DictionaryAttr> attrs) {
113   setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
114 }
setAllResultAttrDicts(Operation * op,ArrayRef<Attribute> attrs)115 void mlir::function_like_impl::setAllResultAttrDicts(
116     Operation *op, ArrayRef<Attribute> attrs) {
117   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
118     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
119   });
120   setAllArgResAttrDicts(op, getResultDictAttrName(),
121                         llvm::to_vector<8>(wrappedAttrs));
122 }
123 
insertFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<DictionaryAttr> argAttrs,ArrayRef<Optional<Location>> argLocs,unsigned originalNumArgs,Type newType)124 void mlir::function_like_impl::insertFunctionArguments(
125     Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
126     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs,
127     unsigned originalNumArgs, Type newType) {
128   assert(argIndices.size() == argTypes.size());
129   assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
130   assert(argIndices.size() == argLocs.size() || argLocs.empty());
131   if (argIndices.empty())
132     return;
133 
134   // There are 3 things that need to be updated:
135   // - Function type.
136   // - Arg attrs.
137   // - Block arguments of entry block.
138   Block &entry = op->getRegion(0).front();
139 
140   // Update the argument attributes of the function.
141   auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
142   if (oldArgAttrs || !argAttrs.empty()) {
143     SmallVector<DictionaryAttr, 4> newArgAttrs;
144     newArgAttrs.reserve(originalNumArgs + argIndices.size());
145     unsigned oldIdx = 0;
146     auto migrate = [&](unsigned untilIdx) {
147       if (!oldArgAttrs) {
148         newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
149       } else {
150         auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
151         newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
152                            oldArgAttrRange.begin() + untilIdx);
153       }
154       oldIdx = untilIdx;
155     };
156     for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
157       migrate(argIndices[i]);
158       newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
159     }
160     migrate(originalNumArgs);
161     setAllArgAttrDicts(op, newArgAttrs);
162   }
163 
164   // Update the function type and any entry block arguments.
165   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
166   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
167     entry.insertArgument(argIndices[i], argTypes[i],
168                          argLocs.empty() ? Optional<Location>{} : argLocs[i]);
169 }
170 
insertFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,TypeRange resultTypes,ArrayRef<DictionaryAttr> resultAttrs,unsigned originalNumResults,Type newType)171 void mlir::function_like_impl::insertFunctionResults(
172     Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
173     ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
174     Type newType) {
175   assert(resultIndices.size() == resultTypes.size());
176   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
177   if (resultIndices.empty())
178     return;
179 
180   // There are 2 things that need to be updated:
181   // - Function type.
182   // - Result attrs.
183 
184   // Update the result attributes of the function.
185   auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
186   if (oldResultAttrs || !resultAttrs.empty()) {
187     SmallVector<DictionaryAttr, 4> newResultAttrs;
188     newResultAttrs.reserve(originalNumResults + resultIndices.size());
189     unsigned oldIdx = 0;
190     auto migrate = [&](unsigned untilIdx) {
191       if (!oldResultAttrs) {
192         newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
193       } else {
194         auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
195         newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
196                               oldResultAttrsRange.begin() + untilIdx);
197       }
198       oldIdx = untilIdx;
199     };
200     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
201       migrate(resultIndices[i]);
202       newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
203                                                    : resultAttrs[i]);
204     }
205     migrate(originalNumResults);
206     setAllResultAttrDicts(op, newResultAttrs);
207   }
208 
209   // Update the function type.
210   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
211 }
212 
eraseFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,unsigned originalNumArgs,Type newType)213 void mlir::function_like_impl::eraseFunctionArguments(
214     Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
215     Type newType) {
216   // There are 3 things that need to be updated:
217   // - Function type.
218   // - Arg attrs.
219   // - Block arguments of entry block.
220   Block &entry = op->getRegion(0).front();
221 
222   // Update the argument attributes of the function.
223   if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
224     SmallVector<DictionaryAttr, 4> newArgAttrs;
225     newArgAttrs.reserve(argAttrs.size());
226     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
227       newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
228     });
229     setAllArgAttrDicts(op, newArgAttrs);
230   }
231 
232   // Update the function type and any entry block arguments.
233   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
234   entry.eraseArguments(argIndices);
235 }
236 
eraseFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,unsigned originalNumResults,Type newType)237 void mlir::function_like_impl::eraseFunctionResults(
238     Operation *op, ArrayRef<unsigned> resultIndices,
239     unsigned originalNumResults, Type newType) {
240   // There are 2 things that need to be updated:
241   // - Function type.
242   // - Result attrs.
243 
244   // Update the result attributes of the function.
245   if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
246     SmallVector<DictionaryAttr, 4> newResultAttrs;
247     newResultAttrs.reserve(resAttrs.size());
248     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
249       newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
250     });
251     setAllResultAttrDicts(op, newResultAttrs);
252   }
253 
254   // Update the function type.
255   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // Function type signature.
260 //===----------------------------------------------------------------------===//
261 
getFunctionType(Operation * op)262 FunctionType mlir::function_like_impl::getFunctionType(Operation *op) {
263   assert(op->hasTrait<OpTrait::FunctionLike>());
264   return op->getAttrOfType<TypeAttr>(getTypeAttrName())
265       .getValue()
266       .cast<FunctionType>();
267 }
268 
setFunctionType(Operation * op,FunctionType newType)269 void mlir::function_like_impl::setFunctionType(Operation *op,
270                                                FunctionType newType) {
271   assert(op->hasTrait<OpTrait::FunctionLike>());
272   FunctionType oldType = getFunctionType(op);
273   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
274 
275   // Functor used to update the argument and result attributes of the function.
276   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
277                           unsigned newCount, auto setAttrFn) {
278     if (oldCount == newCount)
279       return;
280     // The new type has no arguments/results, just drop the attribute.
281     if (newCount == 0) {
282       op->removeAttr(attrName);
283       return;
284     }
285     ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
286     if (!attrs)
287       return;
288 
289     // The new type has less arguments/results, take the first N attributes.
290     if (newCount < oldCount)
291       return setAttrFn(op, attrs.getValue().take_front(newCount));
292 
293     // Otherwise, the new type has more arguments/results. Initialize the new
294     // arguments/results with empty attributes.
295     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
296     newAttrs.resize(newCount);
297     setAttrFn(op, newAttrs);
298   };
299 
300   // Update the argument and result attributes.
301   updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(),
302                newType.getNumInputs(), [&](Operation *op, auto &&attrs) {
303                  setAllArgAttrDicts(op, attrs);
304                });
305   updateAttrFn(
306       function_like_impl::getResultDictAttrName(), oldType.getNumResults(),
307       newType.getNumResults(),
308       [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // Function body.
313 //===----------------------------------------------------------------------===//
314 
getFunctionBody(Operation * op)315 Region &mlir::function_like_impl::getFunctionBody(Operation *op) {
316   assert(op->hasTrait<OpTrait::FunctionLike>());
317   return op->getRegion(0);
318 }
319